import copy
import torch
from torch import nn


def extract_name_kwargs(obj):
    if isinstance(obj, dict):
        obj = copy.copy(obj)
        name = obj.pop("name")
        kwargs = obj
    else:
        name = obj
        kwargs = {}

    return (name, kwargs)


def get_norm_layer(norm, features):
    name, kwargs = extract_name_kwargs(norm)

    if name is None:
        return nn.Identity(**kwargs)

    if name == "layer":
        return nn.LayerNorm((features,), **kwargs)

    if name == "batch":
        return nn.BatchNorm2d(features, **kwargs)

    if name == "instance":
        return nn.InstanceNorm2d(features, **kwargs)

    raise ValueError("Unknown Layer: '%s'" % name)


def get_norm_layer_fn(norm):
    return lambda features: get_norm_layer(norm, features)


def get_activ_layer(activ):
    name, kwargs = extract_name_kwargs(activ)

    if (name is None) or (name == "linear"):
        return nn.Identity()

    if name == "gelu":
        return nn.GELU(**kwargs)

    if name == "relu":
        return nn.ReLU(inplace=True, **kwargs)

    if name == "leakyrelu":
        return nn.LeakyReLU(inplace=True, **kwargs)

    if name == "tanh":
        return nn.Tanh()

    if name == "sigmoid":
        return nn.Sigmoid()

    raise ValueError("Unknown activation: '%s'" % name)


def select_activation(activation):
    return get_activ_layer(activation)


def select_optimizer(parameters, optimizer):
    name, kwargs = extract_name_kwargs(optimizer)

    if name == "AdamW":
        return torch.optim.AdamW(parameters, **kwargs)

    if name == "Adam":
        return torch.optim.Adam(parameters, **kwargs)

    raise ValueError("Unknown optimizer: '%s'" % name)


def select_loss(loss):
    name, kwargs = extract_name_kwargs(loss)

    if name.lower() in ["l1", "mae"]:
        return nn.L1Loss(**kwargs)

    if name.lower() in ["l2", "mse"]:
        return nn.MSELoss(**kwargs)

    raise ValueError("Unknown loss: '%s'" % name)
